import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.optical_flow import raft
from torchvision.models.optical_flow._utils import make_coords_grid, upsample_flow
from torch import Tensor

class RPO_RAFT(nn.Module):
    """RAFT-style iterative flow estimator that consumes precomputed feature maps.

    The module owns a correlation block, an update block (motion encoder,
    recurrent block and flow head) and a mask predictor, all built from
    ``torchvision.models.optical_flow.raft`` components. It has no feature or
    context encoder of its own: ``forward`` receives the encoded feature maps
    and reuses the first frame's feature map as the hidden-state/context
    source.
    """

    def __init__(
        self,
        context_encoder_units=384,
        motion_encoder_corr_layers=(96,),
        motion_encoder_flow_layers=(64, 32),
        motion_encoder_out_channels=82,
        recurrent_block_hidden_state_size=96,
        recurrent_block_kernel_size=(3,),
        recurrent_block_padding=(1,),
        flow_head_hidden_size=128,
        corr_block_num_levels=4,
        corr_block_radius=4,
        mask_predictor_hidden_size=256,
        mask_predictor_multiplier=0.25,
    ):
        """Build the correlation block, update block and mask predictor.

        Args:
            context_encoder_units: Channel count of the tensor that is split
                into hidden state and context in ``forward``.
            motion_encoder_corr_layers: Forwarded to ``raft.MotionEncoder``
                as ``corr_layers``.
            motion_encoder_flow_layers: Forwarded to ``raft.MotionEncoder``
                as ``flow_layers``.
            motion_encoder_out_channels: Forwarded to ``raft.MotionEncoder``
                as ``out_channels``.
            recurrent_block_hidden_state_size: Hidden-state channel count of
                the recurrent block; also the input width of the flow head
                and of the mask predictor.
            recurrent_block_kernel_size: Forwarded to ``raft.RecurrentBlock``.
            recurrent_block_padding: Forwarded to ``raft.RecurrentBlock``.
            flow_head_hidden_size: Forwarded to ``raft.FlowHead``.
            corr_block_num_levels: Forwarded to ``raft.CorrBlock`` as
                ``num_levels``.
            corr_block_radius: Forwarded to ``raft.CorrBlock`` as ``radius``.
            mask_predictor_hidden_size: Forwarded to ``raft.MaskPredictor``
                as ``hidden_size``.
            mask_predictor_multiplier: Forwarded to ``raft.MaskPredictor``
                as ``multiplier``.

        Raises:
            ValueError: If ``context_encoder_units`` does not exceed
                ``recurrent_block_hidden_state_size``.
        """
        super().__init__()

        # Whatever is left after carving out the hidden state is the context;
        # a non-positive width can never produce a model that runs.
        out_channels_context = context_encoder_units - recurrent_block_hidden_state_size
        if out_channels_context <= 0:
            raise ValueError(
                f"context_encoder_units={context_encoder_units} must be strictly greater than "
                f"recurrent_block_hidden_state_size={recurrent_block_hidden_state_size}"
            )

        self.corr_block = raft.CorrBlock(num_levels=corr_block_num_levels, radius=corr_block_radius)
        motion_encoder = raft.MotionEncoder(
            in_channels_corr=self.corr_block.out_channels,
            corr_layers=motion_encoder_corr_layers,
            flow_layers=motion_encoder_flow_layers,
            out_channels=motion_encoder_out_channels,
        )
        recurrent_block = raft.RecurrentBlock(
            input_size=motion_encoder.out_channels + out_channels_context,
            hidden_size=recurrent_block_hidden_state_size,
            kernel_size=recurrent_block_kernel_size,
            padding=recurrent_block_padding,
        )
        flow_head = raft.FlowHead(in_channels=recurrent_block_hidden_state_size, hidden_size=flow_head_hidden_size)
        self.update_block = raft.UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head)
        self.mask_predictor = raft.MaskPredictor(
            in_channels=recurrent_block_hidden_state_size,
            hidden_size=mask_predictor_hidden_size,
            multiplier=mask_predictor_multiplier,
        )

    def forward(self, fmaps: Tensor, image1: Tensor, num_flow_updates: int = 12):
        """Iteratively refine a flow estimate from precomputed feature maps.

        Args:
            fmaps: Feature maps of both frames stacked along dim 0 (first
                half: frame 1, second half: frame 2). The spatial size must
                be ``(H // 8, W // 8)`` and the channel count must exceed the
                update block's hidden-state size.
            image1: 4-D tensor used only for its shape ``(B, _, H, W)``.
            num_flow_updates: Number of refinement iterations.

        Returns:
            A list with one entry per iteration, each the result of
            ``upsample_flow`` on the current flow estimate (presumably at the
            resolution of ``image1`` -- confirm against ``upsample_flow``).

        Raises:
            ValueError: If ``fmaps`` does not hold two frames per sample of
                ``image1``, is not downsampled by 8, or has too few channels
                to be split into hidden state and context.
        """
        batch_size, _, h, w = image1.shape
        feat_h, feat_w = h // 8, w // 8

        # Fail early with a clear message; a mismatch would otherwise surface
        # as an opaque error inside correlation sampling or the update block.
        if fmaps.shape[0] != 2 * batch_size:
            raise ValueError(
                f"fmaps should stack the feature maps of both frames along dim 0: "
                f"expected batch size {2 * batch_size}, got {fmaps.shape[0]}"
            )

        fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
        if fmap1.shape[-2:] != (feat_h, feat_w):
            raise ValueError("The feature encoder should downsample H and W by 8")
        self.corr_block.build_pyramid(fmap1, fmap2)

        # fmap1 doubles as the context-encoder output: its leading channels
        # seed the recurrent hidden state and the remainder is the context.
        hidden_state_size = self.update_block.hidden_state_size
        out_channels_context = fmap1.shape[1] - hidden_state_size
        if out_channels_context <= 0:
            raise ValueError(
                f"The context encoder outputs {fmap1.shape[1]} channels, but it should have strictly more than hidden_state={hidden_state_size} channels"
            )

        hidden_state, context = torch.split(fmap1, [hidden_state_size, out_channels_context], dim=1)
        hidden_state = torch.tanh(hidden_state)
        context = F.relu(context)

        coords0 = make_coords_grid(batch_size, feat_h, feat_w).to(fmap1.device)
        # Both grids start identical; coords1 is rebound (never mutated in
        # place) below, so a clone of coords0 is all that is needed.
        coords1 = coords0.clone()

        flow_predictions = []
        for _ in range(num_flow_updates):
            coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
            corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)

            flow = coords1 - coords0
            hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)

            coords1 = coords1 + delta_flow

            # mask_predictor may be set to None externally; upsample_flow is
            # then called with up_mask=None.
            up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
            upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
            flow_predictions.append(upsampled_flow)
        return flow_predictions
